{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# tf.TFRecord API使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "# 导入\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl,np,pd,sklearn,tf,keras:\n",
    "    print(module.__name__,module.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tfrecord 基础API使用\n",
    "\n",
    "tfrecord 是一种文件格式\n",
    "\n",
    "-> tf.train.Example\n",
    "\n",
    "    -> tf.train.Features -> {\"key\": tf.train.Feature}\n",
    "   \n",
    "        -> tf.train.Feature -> tf.train.ByteList/FloatList/Int64List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "value: \"machine learning\"\n",
      "value: \"cc150\"\n",
      "\n",
      "value: 15.5\n",
      "value: 9.5\n",
      "value: 7.0\n",
      "value: 8.0\n",
      "\n",
      "value: 42\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# tf.train.ByteList/FloatList/Int64List\n",
    "favorite_books = [name.encode('utf-8') for name in ['machine learning', 'cc150']]\n",
    "favorite_books_bytelist = tf.train.BytesList(value = favorite_books)\n",
    "print(favorite_books_bytelist)\n",
    "\n",
    "hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])\n",
    "print(hours_floatlist)\n",
    "\n",
    "age_int64list = tf.train.Int64List(value = [42])\n",
    "print(age_int64list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "feature {\n",
      "  key: \"age\"\n",
      "  value {\n",
      "    int64_list {\n",
      "      value: 42\n",
      "    }\n",
      "  }\n",
      "}\n",
      "feature {\n",
      "  key: \"favorite_books\"\n",
      "  value {\n",
      "    bytes_list {\n",
      "      value: \"machine learning\"\n",
      "      value: \"cc150\"\n",
      "    }\n",
      "  }\n",
      "}\n",
      "feature {\n",
      "  key: \"hours\"\n",
      "  value {\n",
      "    float_list {\n",
      "      value: 15.5\n",
      "      value: 9.5\n",
      "      value: 7.0\n",
      "      value: 8.0\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# tf.train.Features and tf.train.Feature\n",
    "features = tf.train.Features(\n",
    "    feature = {\n",
    "        'favorite_books': tf.train.Feature(bytes_list = favorite_books_bytelist),\n",
    "        'hours': tf.train.Feature(float_list = hours_floatlist),\n",
    "        'age': tf.train.Feature(int64_list = age_int64list)\n",
    "    }\n",
    ")\n",
    "print(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features {\n",
      "  feature {\n",
      "    key: \"age\"\n",
      "    value {\n",
      "      int64_list {\n",
      "        value: 42\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  feature {\n",
      "    key: \"favorite_books\"\n",
      "    value {\n",
      "      bytes_list {\n",
      "        value: \"machine learning\"\n",
      "        value: \"cc150\"\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  feature {\n",
      "    key: \"hours\"\n",
      "    value {\n",
      "      float_list {\n",
      "        value: 15.5\n",
      "        value: 9.5\n",
      "        value: 7.0\n",
      "        value: 8.0\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n",
      "b'\\n\\\\\\n-\\n\\x0efavorite_books\\x12\\x1b\\n\\x19\\n\\x10machine learning\\n\\x05cc150\\n\\x0c\\n\\x03age\\x12\\x05\\x1a\\x03\\n\\x01*\\n\\x1d\\n\\x05hours\\x12\\x14\\x12\\x12\\n\\x10\\x00\\x00xA\\x00\\x00\\x18A\\x00\\x00\\xe0@\\x00\\x00\\x00A'\n"
     ]
    }
   ],
   "source": [
    "# tf.train.Example\n",
    "example = tf.train.Example(features=features)\n",
    "print(example)\n",
    "\n",
    "# 序列化，压缩大小\n",
    "serialized_example = example.SerializeToString()\n",
    "print(serialized_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 对tfrecord进行保存、读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = os.path.join('tfrecord_basic')\n",
    "if not  os.path.exists(output_dir):\n",
    "    os.mkdir(output_dir)\n",
    "filename = 'test.tfrecords'\n",
    "filename_fullpath = os.path.join(output_dir, filename)\n",
    "\n",
    "# 保存example\n",
    "with tf.io.TFRecordWriter(filename_fullpath) as writer:\n",
    "    for i in range(3):\n",
    "        writer.write(serialized_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'\\n\\\\\\n-\\n\\x0efavorite_books\\x12\\x1b\\n\\x19\\n\\x10machine learning\\n\\x05cc150\\n\\x0c\\n\\x03age\\x12\\x05\\x1a\\x03\\n\\x01*\\n\\x1d\\n\\x05hours\\x12\\x14\\x12\\x12\\n\\x10\\x00\\x00xA\\x00\\x00\\x18A\\x00\\x00\\xe0@\\x00\\x00\\x00A', shape=(), dtype=string)\n",
      "tf.Tensor(b'\\n\\\\\\n-\\n\\x0efavorite_books\\x12\\x1b\\n\\x19\\n\\x10machine learning\\n\\x05cc150\\n\\x0c\\n\\x03age\\x12\\x05\\x1a\\x03\\n\\x01*\\n\\x1d\\n\\x05hours\\x12\\x14\\x12\\x12\\n\\x10\\x00\\x00xA\\x00\\x00\\x18A\\x00\\x00\\xe0@\\x00\\x00\\x00A', shape=(), dtype=string)\n",
      "tf.Tensor(b'\\n\\\\\\n-\\n\\x0efavorite_books\\x12\\x1b\\n\\x19\\n\\x10machine learning\\n\\x05cc150\\n\\x0c\\n\\x03age\\x12\\x05\\x1a\\x03\\n\\x01*\\n\\x1d\\n\\x05hours\\x12\\x14\\x12\\x12\\n\\x10\\x00\\x00xA\\x00\\x00\\x18A\\x00\\x00\\xe0@\\x00\\x00\\x00A', shape=(), dtype=string)\n"
     ]
    }
   ],
   "source": [
    "# 读取\n",
    "dataset = tf.data.TFRecordDataset([filename_fullpath])\n",
    "for serialized_example_tensor in dataset:\n",
    "    print(serialized_example_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA01D36D8>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022F8872D908>, 'age': <tf.Tensor: id=46, shape=(), dtype=int64, numpy=42>}\n",
      "machine learning\n",
      "cc150\n",
      "{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA961AA90>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA9614AC8>, 'age': <tf.Tensor: id=65, shape=(), dtype=int64, numpy=42>}\n",
      "machine learning\n",
      "cc150\n",
      "{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA01D35F8>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022F88717278>, 'age': <tf.Tensor: id=84, shape=(), dtype=int64, numpy=42>}\n",
      "machine learning\n",
      "cc150\n"
     ]
    }
   ],
   "source": [
    "# 解析后读取\n",
    "expected_features = {\n",
    "    'favorite_books': tf.io.VarLenFeature(dtype=tf.string),\n",
    "    'hours': tf.io.VarLenFeature(dtype=tf.float32),\n",
    "    'age': tf.io.FixedLenFeature([], dtype=tf.int64),\n",
    "}\n",
    "\n",
    "dataset = tf.data.TFRecordDataset([filename_fullpath])\n",
    "for serialized_example_tensor in dataset:\n",
    "    example = tf.io.parse_single_example(serialized_example_tensor, expected_features)\n",
    "    print(example)\n",
    "    books = tf.sparse.to_dense(example['favorite_books'], default_value=b'')\n",
    "    for book in books:\n",
    "        print(book.numpy().decode('UTF-8'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 压缩zip保存example\n",
    "filename_fullpath_zip = filename_fullpath + '.zip'\n",
    "options = tf.io.TFRecordOptions(compression_type='GZIP')\n",
    "with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:\n",
    "    for i in range(3):\n",
    "        writer.write(serialized_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA9614DD8>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA9640240>, 'age': <tf.Tensor: id=120, shape=(), dtype=int64, numpy=42>}\n",
      "machine learning\n",
      "cc150\n",
      "{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022F88717F98>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA01D3F28>, 'age': <tf.Tensor: id=139, shape=(), dtype=int64, numpy=42>}\n",
      "machine learning\n",
      "cc150\n",
      "{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA961A8D0>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000022FA9640438>, 'age': <tf.Tensor: id=158, shape=(), dtype=int64, numpy=42>}\n",
      "machine learning\n",
      "cc150\n"
     ]
    }
   ],
   "source": [
    "# 对压缩后的tfrecord进行读取\n",
    "dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip], compression_type='GZIP')\n",
    "for serialized_example_tensor in dataset_zip:\n",
    "    example = tf.io.parse_single_example(serialized_example_tensor, expected_features)\n",
    "    print(example)\n",
    "    books = tf.sparse.to_dense(example['favorite_books'], default_value=b'')\n",
    "    for book in books:\n",
    "        print(book.numpy().decode('UTF-8'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成tfrecords文件\n",
    "将csv文件进行转化生成tfrecord文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['test_00.csv', 'test_01.csv', 'test_02.csv', 'test_03.csv', 'test_04.csv', 'test_05.csv', 'test_06.csv', 'test_07.csv', 'test_08.csv', 'test_09.csv', 'train_00.csv', 'train_01.csv', 'train_02.csv', 'train_03.csv', 'train_04.csv', 'train_05.csv', 'train_06.csv', 'train_07.csv', 'train_08.csv', 'train_09.csv', 'train_10.csv', 'train_11.csv', 'train_12.csv', 'train_13.csv', 'train_14.csv', 'train_15.csv', 'train_16.csv', 'train_17.csv', 'train_18.csv', 'train_19.csv', 'valid_00.csv', 'valid_01.csv', 'valid_02.csv', 'valid_03.csv', 'valid_04.csv', 'valid_05.csv', 'valid_06.csv', 'valid_07.csv', 'valid_08.csv', 'valid_09.csv']\n"
     ]
    }
   ],
   "source": [
    "source_dir = 'generate_csv'\n",
    "print(os.listdir(source_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['generate_csv\\\\train_00.csv',\n",
      " 'generate_csv\\\\train_01.csv',\n",
      " 'generate_csv\\\\train_02.csv',\n",
      " 'generate_csv\\\\train_03.csv',\n",
      " 'generate_csv\\\\train_04.csv',\n",
      " 'generate_csv\\\\train_05.csv',\n",
      " 'generate_csv\\\\train_06.csv',\n",
      " 'generate_csv\\\\train_07.csv',\n",
      " 'generate_csv\\\\train_08.csv',\n",
      " 'generate_csv\\\\train_09.csv',\n",
      " 'generate_csv\\\\train_10.csv',\n",
      " 'generate_csv\\\\train_11.csv',\n",
      " 'generate_csv\\\\train_12.csv',\n",
      " 'generate_csv\\\\train_13.csv',\n",
      " 'generate_csv\\\\train_14.csv',\n",
      " 'generate_csv\\\\train_15.csv',\n",
      " 'generate_csv\\\\train_16.csv',\n",
      " 'generate_csv\\\\train_17.csv',\n",
      " 'generate_csv\\\\train_18.csv',\n",
      " 'generate_csv\\\\train_19.csv']\n",
      "['generate_csv\\\\valid_00.csv',\n",
      " 'generate_csv\\\\valid_01.csv',\n",
      " 'generate_csv\\\\valid_02.csv',\n",
      " 'generate_csv\\\\valid_03.csv',\n",
      " 'generate_csv\\\\valid_04.csv',\n",
      " 'generate_csv\\\\valid_05.csv',\n",
      " 'generate_csv\\\\valid_06.csv',\n",
      " 'generate_csv\\\\valid_07.csv',\n",
      " 'generate_csv\\\\valid_08.csv',\n",
      " 'generate_csv\\\\valid_09.csv']\n",
      "['generate_csv\\\\test_00.csv',\n",
      " 'generate_csv\\\\test_01.csv',\n",
      " 'generate_csv\\\\test_02.csv',\n",
      " 'generate_csv\\\\test_03.csv',\n",
      " 'generate_csv\\\\test_04.csv',\n",
      " 'generate_csv\\\\test_05.csv',\n",
      " 'generate_csv\\\\test_06.csv',\n",
      " 'generate_csv\\\\test_07.csv',\n",
      " 'generate_csv\\\\test_08.csv',\n",
      " 'generate_csv\\\\test_09.csv']\n"
     ]
    }
   ],
   "source": [
    "# 根据前缀划分文件\n",
    "def get_filename_by_prefix(source_dir, prefix_name):\n",
    "    all_files = os.listdir(source_dir)\n",
    "    results = []\n",
    "    for filename in all_files:\n",
    "        if filename.startswith(prefix_name):\n",
    "            results.append(os.path.join(source_dir, filename))\n",
    "    return results\n",
    "\n",
    "train_filenames = get_filename_by_prefix(source_dir, 'train')\n",
    "valid_filenames = get_filename_by_prefix(source_dir, 'valid')\n",
    "test_filenames = get_filename_by_prefix(source_dir, 'test')\n",
    "\n",
    "import pprint\n",
    "pprint.pprint(train_filenames)\n",
    "pprint.pprint(valid_filenames)\n",
    "pprint.pprint(test_filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取csv文件\n",
    "def parse_csv_line(line, n_fields=9):\n",
    "    defs = [tf.constant(np.nan)] * n_fields\n",
    "    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)\n",
    "    x = tf.stack(parsed_fields[0:-1]) # 将前n-1条数据转化为x向量\n",
    "    y = tf.stack(parsed_fields[-1:]) # 将最后一条转化为y向量\n",
    "    return x, y\n",
    "\n",
    "def csv_reader_dateset(filenames, n_readers=5, batch_size=32,\n",
    "                       n_parse_threads=5, shuffle_buffer_size=10000):\n",
    "    # 1.filename -> filename dataset\n",
    "    dataset = tf.data.Dataset.list_files(filenames)\n",
    "    dataset = dataset.repeat() # 将数据重复无限次\n",
    "    # 2.filename dataset -> text dataset 一对多的关系\n",
    "    dataset = dataset.interleave(\n",
    "        lambda filename : tf.data.TextLineDataset(filename).skip(1),\n",
    "        cycle_length=n_readers\n",
    "    )\n",
    "    dataset.shuffle(shuffle_buffer_size) # 将数据进行混排\n",
    "    # 3.parse csv 一对一关系\n",
    "    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    return dataset\n",
    "\n",
    "batch_size = 32\n",
    "train_set = csv_reader_dateset(train_filenames, batch_size=batch_size)\n",
    "valid_set = csv_reader_dateset(valid_filenames, batch_size=batch_size)\n",
    "test_set = csv_reader_dateset(test_filenames, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将csv读取出来的dataset转化为tfrecord格式\n",
    "def serialize_example(x, y):\n",
    "    \"\"\"Converts x,y to tf.train.Example and serialize\"\"\"\n",
    "    input_features = tf.train.FloatList(value = x)\n",
    "    lable = tf.train.FloatList(value = y)\n",
    "    features = tf.train.Features(\n",
    "        feature = {\n",
    "            'input_features': tf.train.Feature(float_list = input_features),\n",
    "            'lable': tf.train.Feature(float_list = lable)\n",
    "        }\n",
    "    )\n",
    "    example = tf.train.Example(features = features)\n",
    "    return example.SerializeToString()\n",
    "\n",
    "def csv_dataset_to_tfrecord(base_filename, dataset, n_shards,\n",
    "                            steps_per_shard, compression_type = None):\n",
    "    options = tf.io.TFRecordOptions(compression_type=compression_type)\n",
    "    all_filenames = []\n",
    "    \n",
    "    for shard_id in range(n_shards):\n",
    "        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(\n",
    "            base_filename, shard_id, n_shards)\n",
    "        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:\n",
    "            # 从当前样本取出前steps_per_shard个\n",
    "            for x_batch, y_batch in dataset.take(steps_per_shard): \n",
    "                # 对获得的每个batch进行解包\n",
    "                for x_example, y_example in zip(x_batch, y_batch): \n",
    "                    # 封装成序列化的example，写到文件中\n",
    "                    writer.write(serialize_example(x_example, y_example))\n",
    "        all_filenames.append(filename_fullpath)\n",
    "    return all_filenames\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_shards = 20\n",
    "train_steps_per_shard = 11610 // batch_size // n_shards\n",
    "valid_steps_per_shard = 3880 // batch_size // n_shards\n",
    "test_steps_per_shard = 5170 // batch_size // n_shards\n",
    "\n",
    "output_dir = 'generate_tfrecords'\n",
    "if not os.path.exists(output_dir):\n",
    "    os.mkdir(output_dir)\n",
    "    \n",
    "train_basename = os.path.join(output_dir, 'train')\n",
    "valid_basename = os.path.join(output_dir, 'valid')\n",
    "test_basename = os.path.join(output_dir, 'test')\n",
    "\n",
    "train_tfrecord_filenames = csv_dataset_to_tfrecord(\n",
    "    train_basename, train_set, n_shards, train_steps_per_shard, None)\n",
    "valid_tfrecord_filenames = csv_dataset_to_tfrecord(\n",
    "    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)\n",
    "test_tfrecord_filenames = csv_dataset_to_tfrecord(\n",
    "    test_basename, test_set, n_shards, test_steps_per_shard, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保存成压缩后的\n",
    "n_shards = 20\n",
    "train_steps_per_shard = 11610 // batch_size // n_shards\n",
    "valid_steps_per_shard = 3880 // batch_size // n_shards\n",
    "test_steps_per_shard = 5170 // batch_size // n_shards\n",
    "\n",
    "output_dir = 'generate_tfrecords_zip'\n",
    "if not os.path.exists(output_dir):\n",
    "    os.mkdir(output_dir)\n",
    "    \n",
    "train_basename = os.path.join(output_dir, 'train')\n",
    "valid_basename = os.path.join(output_dir, 'valid')\n",
    "test_basename = os.path.join(output_dir, 'test')\n",
    "\n",
    "train_tfrecord_filenames = csv_dataset_to_tfrecord(\n",
    "    train_basename, train_set, n_shards, train_steps_per_shard, compression_type='GZIP')\n",
    "valid_tfrecord_filenames = csv_dataset_to_tfrecord(\n",
    "    valid_basename, valid_set, n_shards, valid_steps_per_shard, compression_type='GZIP')\n",
    "test_tfrecord_filenames = csv_dataset_to_tfrecord(\n",
    "    test_basename, test_set, n_shards, test_steps_per_shard, compression_type='GZIP')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取tfrecord文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['generate_tfrecords_zip\\\\train_00000-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00001-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00002-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00003-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00004-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00005-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00006-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00007-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00008-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00009-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00010-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00011-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00012-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00013-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00014-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00015-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00016-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00017-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00018-of-00020',\n",
      " 'generate_tfrecords_zip\\\\train_00019-of-00020']\n",
      "['generate_tfrecords_zip\\\\valid_00000-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00001-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00002-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00003-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00004-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00005-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00006-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00007-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00008-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00009-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00010-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00011-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00012-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00013-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00014-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00015-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00016-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00017-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00018-of-00020',\n",
      " 'generate_tfrecords_zip\\\\valid_00019-of-00020']\n",
      "['generate_tfrecords_zip\\\\test_00000-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00001-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00002-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00003-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00004-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00005-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00006-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00007-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00008-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00009-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00010-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00011-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00012-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00013-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00014-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00015-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00016-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00017-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00018-of-00020',\n",
      " 'generate_tfrecords_zip\\\\test_00019-of-00020']\n"
     ]
    }
   ],
   "source": [
    "pprint.pprint(train_tfrecord_filenames)\n",
    "pprint.pprint(valid_tfrecord_filenames)\n",
    "pprint.pprint(test_tfrecord_filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义期望的特征\n",
    "expected_features = {\n",
    "    'input_features': tf.io.FixedLenFeature([8], dtype=tf.float32),\n",
    "    'lable': tf.io.FixedLenFeature([1], dtype=tf.float32)\n",
    "}\n",
    "\n",
    "# 解析example\n",
    "def parse_example(serialize_example):\n",
    "    example = tf.io.parse_single_example(serialize_example, expected_features)\n",
    "    return example['input_features'], example['lable']\n",
    "\n",
    "# 从文件列表到dataset的转变\n",
    "def tfrecords_reader_dateset(filenames, n_readers=5, batch_size=32,\n",
    "                             n_parse_threads=5, shuffle_buffer_size=10000):\n",
    "    # 1.filename -> filename dataset\n",
    "    dataset = tf.data.Dataset.list_files(filenames)\n",
    "    dataset = dataset.repeat() # 将数据重复无限次\n",
    "    # 2.filename dataset -> tfrecord dataset 一对多的关系\n",
    "    dataset = dataset.interleave(\n",
    "        lambda filename : tf.data.TFRecordDataset(filename, compression_type='GZIP'),\n",
    "        cycle_length=n_readers\n",
    "    )\n",
    "    dataset.shuffle(shuffle_buffer_size) # 将数据进行混排\n",
    "    # 3.parse csv 一对一关系\n",
    "    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ 0.4240821   0.91296333 -0.04437482 -0.15297213 -0.24727628 -0.10539167\n",
      "   0.86126745 -1.335779  ]\n",
      " [ 0.04326301 -1.0895426  -0.38878718 -0.10789865 -0.68186635 -0.0723871\n",
      "  -0.8883662   0.8213992 ]\n",
      " [ 0.04326301 -1.0895426  -0.38878718 -0.10789865 -0.68186635 -0.0723871\n",
      "  -0.8883662   0.8213992 ]], shape=(3, 8), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[3.955]\n",
      " [1.426]\n",
      " [1.426]], shape=(3, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[-1.0591781   1.3935647  -0.02633197 -0.1100676  -0.6138199  -0.09695935\n",
      "   0.3247131  -0.03747724]\n",
      " [ 0.8015443   0.27216142 -0.11624393 -0.20231152 -0.5430516  -0.02103962\n",
      "  -0.5897621  -0.08241846]\n",
      " [-1.0775077  -0.4487407  -0.5680568  -0.14269263 -0.09666677  0.12326469\n",
      "  -0.31448638 -0.4818959 ]], shape=(3, 8), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.672]\n",
      " [3.226]\n",
      " [0.978]], shape=(3, 1), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# test\n",
    "tfrecords_train = tfrecords_reader_dateset(train_tfrecord_filenames, batch_size=3)\n",
    "for x_batch, y_batch in tfrecords_train.take(2):\n",
    "    print(x_batch)\n",
    "    print(y_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成训练中dataset\n",
    "batch_size = 32\n",
    "tfrecords_train_set = tfrecords_reader_dateset(\n",
    "    train_tfrecord_filenames, batch_size=batch_size)\n",
    "tfrecords_valid_set = tfrecords_reader_dateset(\n",
    "    valid_tfrecord_filenames, batch_size=batch_size)\n",
    "tfrecords_test_set = tfrecords_reader_dateset(\n",
    "    test_tfrecord_filenames, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 在keras中使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense (Dense)                (None, 30)                270       \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 31        \n",
      "=================================================================\n",
      "Total params: 301\n",
      "Trainable params: 301\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Train for 348 steps, validate for 120 steps\n",
      "Epoch 1/100\n",
      "348/348 [==============================] - 2s 6ms/step - loss: 2.6051 - val_loss: 0.8830\n",
      "Epoch 2/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.6831 - val_loss: 0.7357\n",
      "Epoch 3/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.5092 - val_loss: 0.5638\n",
      "Epoch 4/100\n",
      "348/348 [==============================] - 1s 3ms/step - loss: 0.4318 - val_loss: 0.4799\n",
      "Epoch 5/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3877 - val_loss: 0.4516\n",
      "Epoch 6/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3853 - val_loss: 0.4359\n",
      "Epoch 7/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3710 - val_loss: 0.4284\n",
      "Epoch 8/100\n",
      "348/348 [==============================] - 1s 3ms/step - loss: 0.3521 - val_loss: 0.4193\n",
      "Epoch 9/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3508 - val_loss: 0.4103\n",
      "Epoch 10/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3565 - val_loss: 0.4111\n",
      "Epoch 11/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3298 - val_loss: 0.3984\n",
      "Epoch 12/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3447 - val_loss: 0.4037\n",
      "Epoch 13/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3249 - val_loss: 0.3919\n",
      "Epoch 14/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3255 - val_loss: 0.3910\n",
      "Epoch 15/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3159 - val_loss: 0.3832\n",
      "Epoch 16/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3216 - val_loss: 0.3861\n",
      "Epoch 17/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3090 - val_loss: 0.3857\n",
      "Epoch 18/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3090 - val_loss: 0.3815\n",
      "Epoch 19/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3067 - val_loss: 0.3767\n",
      "Epoch 20/100\n",
      "348/348 [==============================] - 1s 4ms/step - loss: 0.3079 - val_loss: 0.3754\n"
     ]
    }
   ],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation='relu',input_shape=[8]),\n",
    "    keras.layers.Dense(1)\n",
    "])\n",
    "model.summary()\n",
    "\n",
    "model.compile(loss='mean_squared_error',optimizer='adam')\n",
    "\n",
    "callbacks = [keras.callbacks.EarlyStopping(patience=5,min_delta=1e-2)]\n",
    "history = model.fit(tfrecords_train_set,\n",
    "                    validation_data = tfrecords_valid_set,\n",
    "                    steps_per_epoch = 11160 // batch_size,\n",
    "                    validation_steps = 3870 // batch_size,\n",
    "                    epochs = 100,\n",
    "                    callbacks = callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "161/161 [==============================] - 0s 2ms/step - loss: 0.3255\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.3255322992524005"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
